import torch
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaForSequenceClassification, RobertaTokenizerFast, AdamW

import pandas as pd
from tqdm.auto import tqdm
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
import seaborn as sns
import matplotlib.pyplot as plt

import random
import numpy as np
from sklearn.utils import check_random_state

import gc
import copy

# One seed value shared by every random number generator the script touches.
SEED = 42

random.seed(SEED)        # Python's built-in `random` module
np.random.seed(SEED)     # NumPy's global generator
torch.manual_seed(SEED)  # PyTorch

# RandomState object for scikit-learn, built from the same seed
rng = check_random_state(SEED)

# Load the pre-computed encodings from disk. The rest of the script indexes
# this object by the keys 'input_ids', 'attention_mask' and 'labels'.
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only point this at files that come from a trusted source.
import pickle
with open("encodings/encodings_sentiment_purebpe", 'rb') as f:
    encodings = pickle.load(f)


# Quick test on a small number of random items (disabled; uncomment to use)
import random
import torch
# Number of items to select
# n = 128 # Change as needed
# # Get the size of the original set
# total_samples = len(encodings['input_ids'])
# # Generate random indices
# random_indices = random.sample(range(total_samples), n)
# # Select the corresponding items with the chosen indices and clone them
# encodings = {
#     'input_ids': encodings['input_ids'][random_indices].clone(),
#     'attention_mask': encodings['attention_mask'][random_indices].clone(),
#     'labels': encodings['labels'][random_indices].clone(),
# }


class SentimentDataset(torch.utils.data.Dataset):
    """Map-style dataset over a dict of pre-computed encoding tensors.

    The dict must provide 'input_ids', 'attention_mask' and 'labels';
    the number of samples is the first dimension of 'input_ids'.
    """

    # Keys handed back for every sample, in this order.
    _FIELDS = ('input_ids', 'attention_mask', 'labels')

    def __init__(self, encodings):
        """Keep a reference to the encodings dict (no copy is made here)."""
        self.encodings = encodings

    def __len__(self):
        """Return the number of samples."""
        input_ids = self.encodings['input_ids']
        return input_ids.shape[0]

    def __getitem__(self, idx):
        """Return detached copies of the tensors stored at position ``idx``."""
        sample = {}
        for name, values in self.encodings.items():
            # clone() + detach() so callers never share storage with the dataset
            sample[name] = values[idx].clone().detach()
        return {name: sample[name] for name in self._FIELDS}
    
# Fine-tuning with stratified 10-fold cross-validation.
# Every fold starts again from the same pretrained checkpoint. Per-epoch metrics
# are accumulated in `df` and written to Excel after each epoch. The model with
# the highest weighted F1 seen on any fold/epoch is kept and saved at the end.
import os  # used only in this section, to make sure the results directory exists

num_epochs = 10  # epochs per fold
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

# --- Fold-independent configuration (hoisted out of the fold loop) ---
local_model_path = "../../20240301 model/models/PureBPE_epoch_9_encodings_2"
tag_slovnik = {'NEGATIVE': 0,'NEUTRAL': 1,'POSITIVE': 2}

# label_list[i] is the class name whose numeric id is i
label_list = [None] * len(tag_slovnik)
for tag, index in tag_slovnik.items():
    label_list[index] = tag
class_ids = list(range(len(label_list)))

tensor_keys = ('input_ids', 'attention_mask', 'labels')

results_dir = "vysledky1905/m19"
model_save_path = 'models1905/m19/PureBPE_sentiment'
# Create the output directory up front so a missing folder cannot abort the
# run at the first Excel write, after a whole epoch has already been trained.
os.makedirs(results_dir, exist_ok=True)

# Best weighted F1 seen so far (over all folds and epochs) and the matching model.
best_f1_score = 0.0
best_model = None

columns = ['fold', 'epoch', 'train_loss', 'eval_loss', 'accuracy', 'precision', 'recall', 'f1', 'confusion_matrix']
results_rows = []  # one dict per (fold, epoch); `df` is rebuilt from this list
df = pd.DataFrame(columns=columns)

for fold, (train_indices, test_indices) in enumerate(skf.split(encodings['input_ids'], encodings['labels'])):
    # Fresh model for this fold, with class names attached to the config
    model = RobertaForSequenceClassification.from_pretrained(local_model_path, num_labels=len(tag_slovnik))
    model.config.id2label = {i: label for i, label in enumerate(label_list)}
    model.config.label2id = {label: i for i, label in enumerate(label_list)}

    # Move the model to the target device before constructing the optimizer
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01) #5e-6

    train_encodings = {key: encodings[key][train_indices].clone() for key in tensor_keys}
    val_encodings = {key: encodings[key][test_indices].clone() for key in tensor_keys}

    train_dataset = SentimentDataset(train_encodings)
    val_dataset = SentimentDataset(val_encodings)

    # NOTE(review): the training loader does not shuffle, so every epoch sees the
    # samples in the same order. Kept as in the original; consider shuffle=True.
    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=False)
    val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)

    for epoch in range(num_epochs):
        # ---------------- training ----------------
        model.train()
        train_loss = 0.0
        loop = tqdm(train_dataloader, total=len(train_dataloader), leave=True)
        loop.set_description(f"Fold {fold+1} - Epoch {epoch+1}/{num_epochs} - Train")
        for batch in loop:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            loop.set_postfix(loss=batch_loss)
            train_loss += batch_loss
        avg_train_loss = train_loss / len(train_dataloader)

        # ---------------- validation ----------------
        model.eval()

        y_true = []
        y_pred = []

        val_loss = 0.0
        loop = tqdm(val_dataloader, total=len(val_dataloader), leave=True)
        loop.set_description(f"Fold {fold+1} - Epoch {epoch+1}/{num_epochs} - Validation")
        with torch.no_grad():
            for batch in loop:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                # A single forward pass yields both the loss and the logits;
                # in eval mode a second pass would only repeat the same result.
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

                batch_loss = outputs.loss.item()
                loop.set_postfix(loss=batch_loss)
                val_loss += batch_loss

                # Collect predictions and references for the metrics
                y_pred.extend(torch.argmax(outputs.logits, dim=1).cpu().tolist())
                y_true.extend(labels.cpu().tolist())

        avg_val_loss = val_loss / len(val_dataloader)

        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
        recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
        f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)

        # Explicit labels keep the matrix the same shape for every fold/epoch,
        # even when a class is absent from both y_true and y_pred.
        con_mat = confusion_matrix(y_true, y_pred, labels=class_ids).tolist()

        # Report the averaged values for this epoch
        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Confusion matrix: {con_mat}")

        # Record this epoch and rebuild the cumulative results table.
        # Building from a list avoids concatenating onto an empty DataFrame.
        results_rows.append({
            'fold': fold + 1,
            'epoch': epoch + 1,
            'train_loss': avg_train_loss,
            'eval_loss': avg_val_loss,
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'confusion_matrix': con_mat
        })
        df = pd.DataFrame(results_rows, columns=columns)

        # Save the cumulative results to Excel after every epoch
        df.to_excel(f"{results_dir}/PureBPE_folds_loss_epoch_{epoch+1}.xlsx")

        # Keep the best model. The threshold must be raised together with the
        # snapshot, otherwise any later epoch with F1 > 0 would replace it.
        if f1 > best_f1_score:
            best_f1_score = f1
            # Park the snapshot on the CPU so it does not occupy GPU memory
            # while the remaining epochs and folds are trained.
            best_model = copy.deepcopy(model).to("cpu")

        gc.collect()

df.to_excel(f"{results_dir}/pureBPE_folds_loss.xlsx")

if best_model is None:
    raise RuntimeError("No epoch reached an F1 score above 0; there is no best model to save.")
best_model.save_pretrained(model_save_path)
